{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/miniconda3/envs/pissa/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from tqdm import tqdm\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading checkpoint shards: 100%|██████████| 4/4 [00:20<00:00,  5.08s/it]\n",
      "The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` model input instead.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors. The output is a scalar.\n",
      "\n",
      "In this tutorial, you will discover how to implement an attention function in Keras.\n",
      "\n",
      "After completing this tutorial, you will know:\n",
      "\n",
      "- How to define an attention function in Keras.\n",
      "- How to use an attention function in a Keras model.\n",
      "- How to use an attention function in a Keras model for sequence classification.\n",
      "\n",
      "Let’s get started.\n",
      "\n",
      "Tutorial Overview\n",
      "\n",
      "This tutorial is divided into three parts;\n"
     ]
    }
   ],
   "source": [
    "model_name = \"/data/mfx/deepseek-ai/DeepSeek-V2-Lite\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n",
    "model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True, torch_dtype=torch.bfloat16, device_map='auto')\n",
    "model.generation_config = GenerationConfig.from_pretrained(model_name)\n",
    "model.generation_config.pad_token_id = model.generation_config.eos_token_id\n",
    "\n",
    "text = \"An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors. The output is\"\n",
    "inputs = tokenizer(text, return_tensors=\"pt\")\n",
    "outputs = model.generate(**inputs.to(model.device), max_new_tokens=100)\n",
    "result = tokenizer.decode(outputs[0], skip_special_tokens=True)\n",
    "print(result)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_heads=model.config.num_attention_heads\n",
    "qk_nope_head_dim = model.config.qk_nope_head_dim\n",
    "qk_rope_head_dim = model.config.qk_rope_head_dim\n",
    "q_head_dim = qk_nope_head_dim + qk_rope_head_dim\n",
    "v_head_dim = model.config.v_head_dim\n",
    "kv_lora_rank = model.config.kv_lora_rank\n",
    "hidden_size = model.config.hidden_size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "8809it [07:02, 20.87it/s]\n"
     ]
    }
   ],
   "source": [
    "for name,module in tqdm(model.named_modules()):\n",
    "    if name.endswith(\"self_attn\"):\n",
    "        # Orthogonal q_proj and k_up_weight\n",
    "        q_proj = module.q_proj.weight.data.to(torch.float32)\n",
    "        q_proj = q_proj.view(num_heads, q_head_dim, hidden_size)\n",
    "        q_nope, q_rope = torch.split(q_proj, [qk_nope_head_dim, qk_rope_head_dim], dim=1) # q_nope(num_head, head_dim, hidden_size)\n",
    "        kv_b_proj = module.kv_b_proj.weight.data.to(torch.float32)\n",
    "        kv_b_proj = kv_b_proj.view(num_heads, qk_nope_head_dim+v_head_dim, kv_lora_rank)\n",
    "        k_nope, value_states = torch.split(kv_b_proj, [qk_nope_head_dim, v_head_dim], dim=1) # k_nope(num_head, head_dim, latent_dim),  value_states(num_head, head_dim, latent_dim)\n",
    "        q_t_k_up = torch.einsum(\"hdD,hdL->hDL\",q_nope, k_nope) # (num_head, head_dim, latent_dim), rank<=head_dim\n",
    "        U,S,V = torch.svd_lowrank(q_t_k_up, qk_nope_head_dim, niter=qk_nope_head_dim) # U(num_head, hidden_size, head_dim), S(num_head, head_dim), V(num_head, latent_dim, head_dim)\n",
    "        q_nope = torch.einsum('hDd,hd->hdD',U,torch.sqrt(S)) # (num_head, head_dim, hidden_size)\n",
    "        k_nope = torch.einsum('hd,hLd->hdL',torch.sqrt(S),V) # (num_head, head_dim, latent_dim)\n",
    "        module.q_proj.weight.data = torch.cat([q_nope, q_rope],dim=1).reshape(num_heads*q_head_dim, hidden_size).contiguous().to(torch.bfloat16)\n",
    "        \n",
    "        \n",
    "        # Orthogonal o_proj and v_up_weight\n",
    "        o_proj = module.o_proj.weight.data.to(torch.float32)\n",
    "        o_proj = o_proj.view(hidden_size, num_heads, v_head_dim).transpose(0,1) # (num_head, hidden_size, head_dim)\n",
    "        o_v_up = torch.einsum(\"hDd,hdL->hDL\",o_proj, value_states) # (num_head, hidden_size, latent_dim), rank<=head_dim\n",
    "        U,S,V = torch.svd_lowrank(o_v_up, v_head_dim, niter=v_head_dim) # U(num_head, hidden_size, head_dim), S(num_head, head_dim), V(num_head, latent_dim, head_dim)\n",
    "        o_proj = torch.einsum('hDd,hd->Dhd',U,torch.sqrt(S)) # (hidden_size, num_head, head_dim)\n",
    "        value_states = torch.einsum('hd,hLd->hdL',torch.sqrt(S),V) # (num_head, head_dim, latent_dim)\n",
    "        module.kv_b_proj.weight.data = torch.cat([k_nope, value_states],dim=1).reshape(num_heads*(qk_nope_head_dim+v_head_dim), kv_lora_rank).contiguous().to(torch.bfloat16)\n",
    "        module.o_proj.weight.data = o_proj.reshape(hidden_size, (num_heads * v_head_dim)).contiguous().to(torch.bfloat16)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors. The output is a scalar, which is the cosine similarity between the query and the key.\n",
      "\n",
      "The attention function is used in the Transformer architecture, which is a neural network architecture that is used for natural language processing tasks such as machine translation and text classification. The attention function is used to compute the importance of each word in the input sequence, and to generate a representation of the input sequence that is more useful for the task at hand.\n",
      "\n",
      "The attention function is defined as follows:\n",
      "\n",
      "$$\n",
      "attention\n"
     ]
    }
   ],
   "source": [
    "outputs = model.generate(**inputs.to(model.device), max_new_tokens=100)\n",
    "result = tokenizer.decode(outputs[0], skip_special_tokens=True)\n",
    "print(result)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.save_pretrained(\"DeepSeek-V2-Lite_transMLA\")\n",
    "#model.push_to_hub(\"fxmeng/DeepSeek-V2-Lite_transMLA\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer.save_pretrained(\"DeepSeek-V2-Lite_transMLA\")\n",
    "#tokenizer.push_to_hub(\"fxmeng/DeepSeek-V2-Lite_transMLA\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pissa",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
